# coding:utf-8
import os
from Feature import *
from Mask import Mask
from collections import OrderedDict
import json


class FtrLine:
	"""JSON key names used in the feature file written by extractFeatures.

	The file has one header object, followed by one object per image.
	"""
	# Keys of the header line.
	FtrType = 'FeatureType'  # human-readable feature name (ftr.name)
	FtrTypeID = 'FeatureTypeID'  # key used in Feature.init_methods
	MaskMethodID = 'MaskMethodID'  # key used in Mask.mask_method, or null
	# Keys of each per-image line.
	Fname = 'FileName'  # path of the image the features came from
	FtrList = 'Feature'  # list of feature vectors, one per mask


def extractFeatures(imgdir, ftrfile, ftrtype, maskID=None):
	"""Extract features for every readable image under *imgdir* into *ftrfile*.

	The output holds one compact JSON object per line. The first line is a
	header with the feature name, ``ftrtype`` and ``maskID`` (``null`` when
	no mask is used). Each following line holds an image path and its list
	of features: one per mask image when *maskID* is given, otherwise a
	single one.

	:param imgdir: directory walked recursively with ``os.walk``.
	:param ftrfile: path of the feature file; it is overwritten.
	:param ftrtype: key into ``Feature.init_methods`` selecting the extractor.
	:param maskID: optional key into ``Mask.mask_method``.
	"""
	ftr = Feature.init_methods[ftrtype](None)
	header = OrderedDict()
	header[FtrLine.FtrType] = ftr.name
	header[FtrLine.FtrTypeID] = ftrtype
	header[FtrLine.MaskMethodID] = maskID  # None is serialized as null
	# ``with`` closes the file even if reading/extracting an image raises.
	with open(ftrfile, 'wb') as f:
		f.write(json.dumps(obj=header, indent=None, separators=(',', ':')) + '\n')
		for root, _dirnames, filenames in os.walk(imgdir):
			for index, filename in enumerate(filenames):
				filepath = os.path.join(root, filename)
				print('[' + str(index + 1) + ']' + u"正在提取特征：" + filepath)
				img = cv.imread(filepath)
				if img is None:
					# Unreadable or non-image file: skip it.
					continue
				if maskID is not None:
					masks = Mask.mask_method[maskID].maskImages(img.shape[:2])
				else:
					# A single ``None`` mask means "use the whole image".
					masks = [None]
				ftrs = []
				for maskimg in masks:
					ftr.extract(img, maskimg)
					ftrs.append(ftr.feature)
				record = OrderedDict()
				record[FtrLine.Fname] = filepath
				record[FtrLine.FtrList] = ftrs
				f.write(json.dumps(obj=record, indent=None, separators=(',', ':')) + '\n')
				# Flush per image so partial results survive an interrupted run.
				f.flush()
	print(u"特征提取完成!!!,写入文件：" + ftrfile)


def retrieve(imgfile, ftrfile, ftrtype, maskID=None):
	"""Score every image recorded in *ftrfile* against *imgfile* and print the ranking.

	Features of the query image are extracted with the same *ftrtype* and
	optional *maskID* used to build *ftrfile*. If the file's header does not
	match those values, or a line is neither a header nor an image record,
	an error message is printed and nothing is ranked. For each image, the
	score is the ``weight``-weighted sum of ``compareWith`` results over the
	masks. Results are sorted by descending score and passed to showResult.

	:param imgfile: path of the query image. Returns silently if it cannot be read.
	:param ftrfile: feature file produced by extractFeatures.
	:param ftrtype: key into ``Feature.init_methods``.
	:param maskID: optional key into ``Mask.mask_method``.
	"""
	img = cv.imread(imgfile)
	if img is None:
		return
	weight = (1,)
	mask_count = 1
	if maskID is not None:
		mask = Mask.mask_method[maskID]
		masks = mask.maskImages(img.shape[:2])
		weight = mask.weight
		mask_count = mask.mask_count
	else:
		# A single ``None`` mask means "use the whole image".
		masks = [None]
	query_ftrs = []
	for maskimg in masks:
		ftr = Feature.init_methods[ftrtype](None)
		ftr.extract(img, maskimg)
		query_ftrs.append(ftr)

	distances = []
	stored_ftrtypeid = None
	stored_maskID = None
	# ``with`` closes the file on the early error returns too.
	with open(ftrfile, 'rb') as f:
		for line in f:
			if not line.strip():
				# Tolerate blank lines instead of failing in json.loads.
				continue
			record = json.loads(line)
			try:
				filename = record[FtrLine.Fname]
				stored_ftrs = record[FtrLine.FtrList]
			except KeyError:
				# Not an image record, so it must be the header line.
				try:
					stored_ftrtypeid = record[FtrLine.FtrTypeID]
					stored_maskID = record[FtrLine.MaskMethodID]
				except KeyError:
					print(u"特征文件错误!!")
					return
				continue
			# The file must have been built with the same feature type and mask.
			if stored_ftrtypeid != ftrtype or stored_maskID != maskID:
				print(u"特征文件错误!!")
				return
			distance = 0
			for i in range(mask_count):
				tmp_ftr = Feature.init_methods[ftrtype](stored_ftrs[i])
				distance += weight[i] * query_ftrs[i].compareWith(tmp_ftr)
			distances.append([filename, distance])
	distances.sort(key=lambda x: x[1], reverse=True)
	showResult(distances)
	print(u"检索完成!!!")


def showResult(dists):
	"""Print one line per ``[filepath, score]`` entry of *dists*, in order.

	Each line shows the file's base name and the score rounded to two
	decimals, followed by ``%``.
	"""
	for dist in dists:
		# os.path.basename also handles '/' separators on Windows, unlike
		# splitting on os.sep alone.
		name = os.path.basename(dist[0])
		print(''.join((u'文件:', name, ' ', u'相似度:', str(round(dist[1], 2)), '%')))
